Data Unzip and Count¶

In [ ]:
# import zipfile
# import os

# def unzip_file(zip_path, extract_to_folder):
#     # Ensure the output folder exists
#     os.makedirs(extract_to_folder, exist_ok=True)

#     # Unzip the file
#     with zipfile.ZipFile(zip_path, 'r') as zip_ref:
#         zip_ref.extractall(extract_to_folder)
#     print(f"Extracted '{zip_path}' to '{extract_to_folder}'")

# zip_file_path = "midis.zip"
# output_folder_path = "data"

# unzip_file(zip_file_path, output_folder_path)
In [ ]:
# import os

# def count_files_in_folder(folder_path):
#     count = sum(
#         1 for entry in os.scandir(folder_path) if entry.is_file()
#     )
#     print(f"Number of files in '{folder_path}': {count}")
#     return count

# folder_path = "data/midis"
# count_files_in_folder(folder_path)

Context of the Dataset¶

The dataset used for this project is GiantMIDI-Piano, a large-scale symbolic music dataset specifically tailored for classical piano music analysis and generation. It contains 10,855 MIDI files corresponding to 2,786 composers. The MIDI files are derived from live piano performances available on YouTube, which were then transcribed using a high-resolution, deep learning-based transcription system. This system is capable of capturing subtle musical elements such as dynamics, timing variations, and pedal usage - making the dataset particularly useful for expressive music generation tasks.

MIDI Creation¶

The MIDI files in the GiantMIDI-Piano dataset were not manually created, nor did we transcribe them ourselves. Instead, they were generated by the original dataset authors using a high-resolution automatic transcription pipeline. The source audio consisted of solo classical piano performances collected from YouTube, selected based on a curated metadata file that included composer names and piece titles. These audio recordings were then processed through a deep learning-based transcription system, designed to convert complex polyphonic piano audio into symbolic MIDI format. This system accurately detects note onsets, offsets, pitches, and velocities, capturing expressive performance elements such as timing and dynamics. The full transcription process, which spanned approximately 200 hours on a single GPU, produced over 10,000 MIDI files.

Data Analysis - Code Overview¶

  1. Loads all MIDI files from a specified directory
  2. Extracts the following features:

    a. Filename

    b. Number of notes

    c. Average pitch

    d. Min/max pitch

    e. Total duration

    f. Polyphony (estimated by note start overlap)

    g. Composer (from filename)

In [ ]:
import pretty_midi
import os
import pandas as pd

def extract_midi_features(file_path):
    try:
        midi_data = pretty_midi.PrettyMIDI(file_path)
        all_notes = [note for instrument in midi_data.instruments for note in instrument.notes if not instrument.is_drum]
        
        if not all_notes:
            return None
        
        pitches = [note.pitch for note in all_notes]
        start_times = [note.start for note in all_notes]
        durations = [note.end - note.start for note in all_notes]
        
        polyphony = len(set(start_times)) / midi_data.get_end_time() if midi_data.get_end_time() > 0 else 0
        
        filename = os.path.basename(file_path)
        composer = filename.split(',')[0].strip() if ',' in filename else "Unknown"

        return {
            'filename': filename,
            'composer': composer,
            'duration_sec': midi_data.get_end_time(),
            'n_notes': len(all_notes),
            'avg_pitch': sum(pitches) / len(pitches),
            'min_pitch': min(pitches),
            'max_pitch': max(pitches),
            'avg_duration': sum(durations) / len(durations),
            'polyphony_score': polyphony
        }
    
    except Exception as e:
        print(f"Failed to process {file_path}: {e}")
        return None

midi_dir = "data/midis"

feature_list = []
for file in os.listdir(midi_dir):
    if file.endswith('.mid'):
        path = os.path.join(midi_dir, file)
        features = extract_midi_features(path)
        if features:
            feature_list.append(features)

midi_df = pd.DataFrame(feature_list)

midi_df.head()
Out[ ]:
filename composer duration_sec n_notes avg_pitch min_pitch max_pitch avg_duration polyphony_score
0 Kirchner, Theodor, 8 Romances, Op.22, m90Rf9AY... Kirchner 124.684896 1159 58.195858 24 91 0.655047 8.782138
1 Jacobi, Karl, Introduction and Polonaise, Op.9... Jacobi 506.805990 10213 68.113287 24 94 0.208727 17.588585
2 Chopin, Frédéric, Nocturne in C minor, B.1... Chopin 216.692708 860 65.243023 31 101 1.508937 3.908761
3 Wieniawski, Józef, Valse-caprice, Op.46, 3BV... Wieniawski 502.740885 4006 65.908637 22 100 0.631758 7.270147
4 Beethoven, Ludwig van, 12 Variations on the Ru... Beethoven 683.522135 6161 65.257101 26 90 0.359308 8.457663
In [ ]:
print(midi_df.columns)
Index(['filename', 'composer', 'duration_sec', 'n_notes', 'avg_pitch',
       'min_pitch', 'max_pitch', 'avg_duration', 'polyphony_score'],
      dtype='object')
In [ ]:
csv_path = "midi_features_head.csv"
midi_df.to_csv(csv_path, index=False)
In [ ]:
import pandas
midi_df = pandas.read_csv('midi_features_head.csv')
# print(midi_df)
midi_df.head()
Out[ ]:
filename composer duration_sec n_notes avg_pitch min_pitch max_pitch avg_duration polyphony_score
0 Kirchner, Theodor, 8 Romances, Op.22, m90Rf9AY... Kirchner 124.684896 1159 58.195858 24 91 0.655047 8.782138
1 Jacobi, Karl, Introduction and Polonaise, Op.9... Jacobi 506.805990 10213 68.113287 24 94 0.208727 17.588585
2 Chopin, Frédéric, Nocturne in C minor, B.1... Chopin 216.692708 860 65.243023 31 101 1.508937 3.908761
3 Wieniawski, Józef, Valse-caprice, Op.46, 3BV... Wieniawski 502.740885 4006 65.908637 22 100 0.631758 7.270147
4 Beethoven, Ludwig van, 12 Variations on the Ru... Beethoven 683.522135 6161 65.257101 26 90 0.359308 8.457663

Loading the Metadata¶

In [ ]:
import pandas
import pandas as pd

metadata_df = pd.read_csv(
    'full_music_pieces_youtube_similarity_pianosoloprob_split.csv',
    delimiter='\t',
    quotechar='"',
    on_bad_lines='skip'
)

# metadata_df.head()

# print(midi_df)
metadata_df.head()
Out[ ]:
surname firstname music nationality birth death youtube_title youtube_id similarity piano_solo_prob audio_name audio_duration giant_midi_piano split surname_in_youtube_title
0 A. Jag Je t'aime Juliette unknown unknown unknown Je t'aime Juliette - A. Jag OXC7Fd0ZN8o 1.000000 6.848339e-01 A., Jag, Je t'aime Juliette, OXC7Fd0ZN8o 69.553469 1.0 validation 1.0
1 Aadler C. A. Floating Islands unknown unknown unknown Mind-Boggling Off-Grid FLOATING Island HOMESTEAD wPhWfjyqCBs 0.333333 NaN NaN NaN NaN NaN NaN
2 Aagesen Truid Cantiones trium vocum Danish 1500 1600 2nd Edition, Motecta Trium Vocum iWROE7EzwlE 0.500000 NaN NaN NaN NaN NaN NaN
3 Aaron Michael Piano Course unknown unknown unknown Michael Aaron Piano Course Lessons Grade 1 Com... V8WvKK-1b2c 1.000000 7.859141e-01 Aaron, Michael, Piano Course, V8WvKK-1b2c 1556.569469 1.0 validation 1.0
4 Aarons Alfred E. Brother Bill unknown unknown unknown Brother Bill Giet2Krl6Ww 0.666667 7.822375e-07 Aarons, Alfred E., Brother Bill, Giet2Krl6Ww 181.333469 0.0 NaN 0.0
  1. surname - Last name of the composer
  2. firstname - First name of the composer
  3. music - Title of the musical piece
  4. nationality - Nationality of the composer
  5. birth - Composer’s year of birth (may be "unknown")
  6. death - Composer’s year of death (may be "unknown")
  7. youtube_title - Title of the original YouTube video
  8. youtube_id - Unique ID of the YouTube video
  9. similarity - Title similarity score between metadata and YouTube title
  10. piano_solo_prob - Model-estimated probability that the piece is a solo piano performance
  11. audio_name - Filename-like ID used for audio and MIDI alignment
  12. audio_duration - Duration of the audio in seconds
  13. giant_midi_piano - Flag indicating inclusion in the final GiantMIDI-Piano dataset (1 = included)
  14. split - Dataset split (train/validation/test) for experimental use
  15. surname_in_youtube_title - Whether the composer’s surname appears in the YouTube video title (1 = yes)

Exploratory Data Analysis (EDA) based on extracted features¶

In [ ]:
import matplotlib.pyplot as plt
import seaborn as sns

sns.set(style="whitegrid")
plt.rcParams["figure.figsize"] = (10, 6)

1. Histogram: Distribution of Piece Durations¶

In [ ]:
plt.figure()
sns.histplot(midi_df['duration_sec'], bins=50, kde=True)
plt.title("Distribution of Piece Durations")
plt.xlabel("Duration (seconds)")
plt.ylabel("Number of Pieces")
plt.show()

2. Histogram: Number of Notes Per Piece¶

In [ ]:
plt.figure()
sns.histplot(midi_df['n_notes'], bins=50, kde=True, color='orange')
plt.title("Distribution of Number of Notes Per Piece")
plt.xlabel("Number of Notes")
plt.ylabel("Number of Pieces")
plt.show()

3. Bar Plot: Top 10 Most Frequent Composers¶

In [ ]:
plt.figure()
top_composers = midi_df['composer'].value_counts().nlargest(10)
sns.barplot(x=top_composers.values, y=top_composers.index, palette="mako")
plt.title("Top 10 Most Frequent Composers")
plt.xlabel("Number of Pieces")
plt.ylabel("Composer")
plt.show()
/tmp/ipykernel_439/2009332229.py:4: FutureWarning: 

Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `y` variable to `hue` and set `legend=False` for the same effect.

  sns.barplot(x=top_composers.values, y=top_composers.index, palette="mako")

4. Scatter Plot: Duration vs. Number of Notes¶

In [ ]:
plt.figure()
sns.scatterplot(data=midi_df, x='duration_sec', y='n_notes', alpha=0.6)
plt.title("Duration vs. Number of Notes")
plt.xlabel("Duration (seconds)")
plt.ylabel("Number of Notes")
plt.show()

5. Scatter Plot: Average Pitch vs. Polyphony¶

In [ ]:
plt.figure()
sns.scatterplot(data=midi_df, x='avg_pitch', y='polyphony_score', alpha=0.6)
plt.title("Average Pitch vs. Polyphony Score")
plt.xlabel("Average Pitch (MIDI number)")
plt.ylabel("Polyphony Score")
plt.show()

Combining the Extracted Features with Metadata CSV¶

In [ ]:
midi_df['base_name'] = midi_df['filename'].str.replace('.mid', '', regex=False)
merged_df = midi_df.merge(metadata_df, left_on='base_name', right_on='audio_name', how='left')
In [ ]:
filtered_df = merged_df[
    (merged_df['giant_midi_piano'] == 1) &
    (merged_df['piano_solo_prob'] > 0.5) &
    (merged_df['surname_in_youtube_title'] == 1)
]

1. Top 10 Most Common Composer Nationalities¶

In [ ]:
top_nations = filtered_df['nationality'].value_counts().nlargest(10)

filtered_nations_df = filtered_df[filtered_df['nationality'].isin(top_nations.index)]

plt.figure(figsize=(10, 6))
sns.countplot(data=filtered_nations_df, y='nationality', order=top_nations.index)
plt.title("Top 10 Most Common Composer Nationalities")
plt.xlabel("Number of Pieces")
plt.ylabel("Nationality")
plt.show()
In [ ]:
filtered_df = filtered_df.copy()

filtered_df['birth'] = pd.to_numeric(filtered_df['birth'], errors='coerce')
filtered_df['death'] = pd.to_numeric(filtered_df['death'], errors='coerce')
filtered_df['century'] = (filtered_df['birth'] // 100 + 1).fillna("Unknown")

2. Distribution of Composer Centuries¶

In [ ]:
filtered_df['century'] = filtered_df['century'].astype(str)
plt.figure(figsize=(10, 6))
sns.countplot(data=filtered_df, x='century', order=sorted(filtered_df['century'].unique()))
plt.title("Distribution of Composer Centuries")
plt.xlabel("Century")
plt.ylabel("Number of Pieces")
plt.show()

3. Probability of Piano being played Solo¶

In [ ]:
sns.histplot(filtered_df['piano_solo_prob'], bins=20)
Out[ ]:
<Axes: xlabel='piano_solo_prob', ylabel='Count'>

4. Correlation Between Musical Features¶

In [ ]:
plt.figure(figsize=(10, 6))
sns.heatmap(filtered_df[['duration_sec', 'n_notes', 'avg_pitch', 'polyphony_score']].corr(), annot=True, cmap="coolwarm")
plt.title("Correlation Between Musical Features")
plt.show()

Modeling + Evaluation: Task1: Code Walkthrough (Symbolic, unconditional generation)¶

Baseline Model: Task1: (Symbolic, unconditional generation)¶

In [ ]:
# Baseline RNN Model
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from miditoolkit import MidiFile
from glob import glob
import numpy as np
import random
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from miditok import REMI, TokenizerConfig
import math
from collections import Counter
from miditok import TokSequence

random.seed(42)

main_dir = "data/midis"

all_midi_paths = [os.path.join(main_dir, f) for f in os.listdir(main_dir) if f.endswith(".mid")]

random.shuffle(all_midi_paths)
train_paths = all_midi_paths[:100]
val_paths = all_midi_paths[100:125]
test_paths = all_midi_paths[125:150]

config = TokenizerConfig(num_velocities=16, use_chords=False, use_programs=False)
tokenizer = REMI(config)
tokenizer.train(vocab_size=1000, files_paths=train_paths)
print("Tokenizer complete")

class MusicDataset(Dataset):
    def __init__(self, midi_paths, tokenizer, seq_len=256):
        self.data = []
        self.token_counts = Counter()
        self.total_tokens = 0
        self.seq_len = seq_len

        for path in midi_paths:
            try:
                midi = MidiFile(path)
                token_seq = tokenizer(midi)[0].tokens
                token_ids = [tokenizer[token] for token in token_seq if token in tokenizer.vocab]

                self.token_counts.update(token_ids)
                self.total_tokens += len(token_ids)

                for i in range(0, len(token_ids) - seq_len, seq_len):
                    self.data.append((token_ids[i:i+seq_len], token_ids[i+1:i+seq_len+1]))
            except Exception as e:
                print(f"[ERROR] {path}: {e}")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        x, y = self.data[idx]
        return torch.tensor(x, dtype=torch.long), torch.tensor(y, dtype=torch.long)

train_dataset = MusicDataset(train_paths, tokenizer)
val_dataset = MusicDataset(val_paths, tokenizer)
test_dataset = MusicDataset(test_paths, tokenizer)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, drop_last=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, drop_last=True)

print("All dataloaders ready")


Tokenizer complete
/tmp/ipykernel_1630/3806623913.py:50: UserWarning: You are using a depreciated `miditoolkit.MidiFile` object. MidiTokis now (>v3.0.0) using symusic.Score as MIDI backend. Your file willbe converted on the fly, however please consider using symusic.
  token_seq = tokenizer(midi)[0].tokens
All dataloaders ready
In [ ]:
class RNNModel(torch.nn.Module):
    def __init__(self, vocab_size, embed_size=256, hidden_size=512, num_layers=2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.RNN(embed_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)
    
    def forward(self, x, hidden=None):
        x = self.embedding(x)
        out, hidden = self.rnn(x, hidden)
        return self.fc(out), hidden  
In [ ]:
import matplotlib.pyplot as plt

def plot_losses(epoch_losses):
    train_losses, val_losses = zip(*epoch_losses)
    epochs = range(1, len(train_losses) + 1)

    plt.plot(epochs, train_losses, label="Train Loss")
    plt.plot(epochs, val_losses, label="Validation Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title("Training & Validation Loss")
    plt.legend()
    plt.grid(True)
    plt.show()
In [ ]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def train_model(model, train_loader, val_loader, epochs=10, lr=1e-3, ckpt_path="best_model_rnn.pt"):
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    
    best_val_loss = float("inf")
    epoch_losses = []

    for epoch in range(epochs):
        model.train()
        total_loss = 0
        for x, y in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
            x, y = x.to(device), y.to(device)
            logits, _ = model(x)
            loss = criterion(logits.view(-1, logits.size(-1)), y.view(-1))
            total_loss += loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        
        avg_train_loss = total_loss / len(train_loader)
        val_loss = evaluate_loss(model, val_loader)
        print(f"Epoch {epoch+1} | Train Loss: {avg_train_loss:.4f} | Val Loss: {val_loss:.4f}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), ckpt_path)
            print(" Checkpointed best model.")

        epoch_losses.append((avg_train_loss, val_loss))

    return epoch_losses

def evaluate_loss(model, dataloader):
    model.eval()
    total_loss = 0
    criterion = nn.CrossEntropyLoss()
    with torch.no_grad():
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            logits, _ = model(x)
            loss = criterion(logits.view(-1, logits.size(-1)), y.view(-1))
            total_loss += loss.item()
    return total_loss / len(dataloader)

def generate(model, start_token_id, max_len=1000):
    model.eval()
    idxs = [start_token_id]
    inp = torch.tensor([[start_token_id]], device=device)
    hidden = None
    for _ in range(max_len - 1):
        logits, hidden = model(inp, hidden) if hasattr(model, "rnn") else (model(inp), None)
        next_token = torch.multinomial(torch.softmax(logits[0, -1], dim=-1), 1).item()
        idxs.append(next_token)
        inp = torch.tensor([[next_token]], device=device)
    return idxs

def save_generated(ids, out_file, tokenizer):
    if not hasattr(tokenizer, 'vocab_inv'):
        tokenizer.vocab_inv = {v: k for k, v in tokenizer.vocab.items()}
    tokens = [tokenizer.vocab_inv[i] for i in ids]
    tok_sequence = TokSequence(tokens=tokens)
    midi = tokenizer.decode([tok_sequence])
    midi.dump_midi(out_file)
    print(f" Saved generated MIDI to: {out_file}")
In [ ]:
vocab_size = len(tokenizer.vocab)
model = RNNModel(vocab_size)
losses = train_model(model, train_loader, val_loader, epochs=20, lr=1e-3)
plot_losses(losses)

model.load_state_dict(torch.load("best_model_rnn.pt"))

start_token = tokenizer.vocab["Bar_None"]
generated_ids = generate(model, start_token, max_len=256)

save_generated(generated_ids, "sample_output_test_rnn.mid", tokenizer)
Epoch 1/20: 100%|██████████| 59/59 [00:02<00:00, 24.34it/s]
Epoch 1 | Train Loss: 3.4076 | Val Loss: 3.0708
 Checkpointed best model.
Epoch 2/20: 100%|██████████| 59/59 [00:02<00:00, 27.93it/s]
Epoch 2 | Train Loss: 3.0016 | Val Loss: 2.9523
 Checkpointed best model.
Epoch 3/20: 100%|██████████| 59/59 [00:02<00:00, 27.92it/s]
Epoch 3 | Train Loss: 2.8642 | Val Loss: 2.8086
 Checkpointed best model.
Epoch 4/20: 100%|██████████| 59/59 [00:02<00:00, 27.89it/s]
Epoch 4 | Train Loss: 2.7517 | Val Loss: 2.7418
 Checkpointed best model.
Epoch 5/20: 100%|██████████| 59/59 [00:02<00:00, 27.97it/s]
Epoch 5 | Train Loss: 2.6997 | Val Loss: 2.7032
 Checkpointed best model.
Epoch 6/20: 100%|██████████| 59/59 [00:02<00:00, 28.10it/s]
Epoch 6 | Train Loss: 2.6623 | Val Loss: 2.6779
 Checkpointed best model.
Epoch 7/20: 100%|██████████| 59/59 [00:02<00:00, 28.12it/s]
Epoch 7 | Train Loss: 2.6249 | Val Loss: 2.6529
 Checkpointed best model.
Epoch 8/20: 100%|██████████| 59/59 [00:02<00:00, 28.10it/s]
Epoch 8 | Train Loss: 2.6002 | Val Loss: 2.6391
 Checkpointed best model.
Epoch 9/20: 100%|██████████| 59/59 [00:02<00:00, 28.09it/s]
Epoch 9 | Train Loss: 2.5716 | Val Loss: 2.6227
 Checkpointed best model.
Epoch 10/20: 100%|██████████| 59/59 [00:02<00:00, 28.06it/s]
Epoch 10 | Train Loss: 2.5515 | Val Loss: 2.6138
 Checkpointed best model.
Epoch 11/20: 100%|██████████| 59/59 [00:02<00:00, 28.06it/s]
Epoch 11 | Train Loss: 2.5282 | Val Loss: 2.5981
 Checkpointed best model.
Epoch 12/20: 100%|██████████| 59/59 [00:02<00:00, 28.01it/s]
Epoch 12 | Train Loss: 2.5119 | Val Loss: 2.6032
Epoch 13/20: 100%|██████████| 59/59 [00:02<00:00, 27.90it/s]
Epoch 13 | Train Loss: 2.4995 | Val Loss: 2.5937
 Checkpointed best model.
Epoch 14/20: 100%|██████████| 59/59 [00:02<00:00, 27.96it/s]
Epoch 14 | Train Loss: 2.4810 | Val Loss: 2.5918
 Checkpointed best model.
Epoch 15/20: 100%|██████████| 59/59 [00:02<00:00, 27.94it/s]
Epoch 15 | Train Loss: 2.4663 | Val Loss: 2.5962
Epoch 16/20: 100%|██████████| 59/59 [00:02<00:00, 27.95it/s]
Epoch 16 | Train Loss: 2.4485 | Val Loss: 2.5818
 Checkpointed best model.
Epoch 17/20: 100%|██████████| 59/59 [00:02<00:00, 27.99it/s]
Epoch 17 | Train Loss: 2.4371 | Val Loss: 2.5817
 Checkpointed best model.
Epoch 18/20: 100%|██████████| 59/59 [00:02<00:00, 27.98it/s]
Epoch 18 | Train Loss: 2.4254 | Val Loss: 2.5843
Epoch 19/20: 100%|██████████| 59/59 [00:02<00:00, 27.98it/s]
Epoch 19 | Train Loss: 2.4067 | Val Loss: 2.5819
Epoch 20/20: 100%|██████████| 59/59 [00:02<00:00, 27.97it/s]
Epoch 20 | Train Loss: 2.3982 | Val Loss: 2.5915
 Saved generated MIDI to: sample_output_test_rnn.mid
In [ ]:
import pretty_midi
import numpy as np
from IPython.display import Audio

def midi_to_audio(midi_path, sample_rate=44100):

    pm = pretty_midi.PrettyMIDI(midi_path)   
    audio = pm.synthesize(fs=sample_rate)
    
    return Audio(audio, rate=sample_rate)

midi_to_audio("sample_output_test_rnn.mid")
Out[ ]:
Your browser does not support the audio element.

Task1: Main Model: LSTM¶

In [ ]:
import os
import glob
import torch
import pretty_midi
import numpy as np
import random
from torch import nn
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from IPython.display import Audio, display
In [ ]:
class MIDIPreprocessor:
    def __init__(self, sequence_length=100):
        self.sequence_length = sequence_length
        self.pitch_range = (21, 109)
        self.num_pitches = self.pitch_range[1] - self.pitch_range[0] + 1
        print(f"Preprocessor initialized: pitch range {self.pitch_range}, sequence length {self.sequence_length}")

    def encode_midi(self, midi_path):
        # print(f"Processing MIDI file: {midi_path}")
        try:
            midi = pretty_midi.PrettyMIDI(midi_path)
            notes = []
            for instrument in midi.instruments:
                if not instrument.is_drum:
                    for note in instrument.notes:
                        if self.pitch_range[0] <= note.pitch <= self.pitch_range[1]:
                            notes.append(note.pitch - self.pitch_range[0])
            # print(f"Extracted {len(notes)} notes")
            return notes
        except Exception as e:
            print(f"Error processing {midi_path}: {e}")
            return []

    def build_sequences(self, all_notes):
        print(f"Building note sequences...")
        sequences = []
        for i in range(len(all_notes) - self.sequence_length):
            seq = all_notes[i:i + self.sequence_length + 1]
            sequences.append(seq)
        print(f"Built {len(sequences)} sequences")
        return sequences
In [ ]:
class MIDIDataset(Dataset):
    def __init__(self, midi_dir, preprocessor, file_limit=100):
        print(f"Loading MIDI dataset from: {midi_dir}")
        self.preprocessor = preprocessor
        all_notes = []

        midi_files = glob.glob(os.path.join(midi_dir, "*.mid"))[:file_limit]
        # print(f"Using first {len(midi_files)} MIDI files.")
        if len(midi_files) == 0:
            raise ValueError("No MIDI files found.")

        for i, file in enumerate(midi_files):
            # print(f"[{i+1}/{len(midi_files)}] Processing: {file}")
            notes = preprocessor.encode_midi(file)
            all_notes += notes

        self.notes = all_notes
        self.sequences = self.preprocessor.build_sequences(self.notes)

        if len(self.sequences) == 0:
            raise ValueError("No sequences built.")
        print(f"Dataset ready: {len(self.sequences)} sequences")

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        sequence = self.sequences[idx]
        input_seq = torch.tensor(sequence[:-1], dtype=torch.long)
        target_seq = torch.tensor(sequence[1:], dtype=torch.long)
        return input_seq, target_seq
In [ ]:
class Task1(nn.Module):
    def __init__(self, vocab_size, embed_dim=128, hidden_dim=256, num_layers=2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x):
        embedded = self.embedding(x)
        out, _ = self.lstm(embedded)
        out = self.fc(out)
        return out
In [ ]:
import torch
import matplotlib.pyplot as plt
from torch import nn

def train_model(model, train_loader, val_loader, preprocessor, epochs=10, lr=0.001, file_limit=100, base_path="symbolic_model"):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    print(f"Starting training on {device}...")

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    model.train()

    train_losses = []
    val_losses = []

    for epoch in range(epochs):
        print(f"\nEpoch {epoch + 1}/{epochs}")
        total_train_loss = 0
        model.train()

        for batch_idx, (x, y) in enumerate(train_loader):
            x = x.to(device)
            y = y.to(device)

            optimizer.zero_grad()
            output = model(x)
            loss = criterion(output.view(-1, output.size(-1)), y.view(-1))
            loss.backward()
            optimizer.step()
            total_train_loss += loss.item()

            if batch_idx % 500 == 0:
                print(f"   Batch {batch_idx}/{len(train_loader)}, Train Loss: {loss.item():.4f}")

        avg_train_loss = total_train_loss / len(train_loader)
        train_losses.append(avg_train_loss)

        model.eval()
        total_val_loss = 0
        with torch.no_grad():
            for x_val, y_val in val_loader:
                x_val = x_val.to(device)
                y_val = y_val.to(device)

                output = model(x_val)
                val_loss = criterion(output.view(-1, output.size(-1)), y_val.view(-1))
                total_val_loss += val_loss.item()

        avg_val_loss = total_val_loss / len(val_loader)
        val_losses.append(avg_val_loss)

        print(f"Epoch {epoch+1} complete. Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")

        torch.save({
            "model_state_dict": model.state_dict(),
            "sequence_length": preprocessor.sequence_length,
            "pitch_range": preprocessor.pitch_range,
        }, f"{base_path}_f{file_limit}_e{epoch+1}.pth")
        print(f"Model saved to {base_path}_f{file_limit}_e{epoch+1}.pth")

    # Plot losses
    plt.figure(figsize=(8, 5))
    plt.plot(train_losses, label="Train Loss")
    plt.plot(val_losses, label="Val Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title("Training and Validation Loss")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.savefig("loss_plot.png")
    plt.show()
    print("Loss plot saved as 'loss_plot.png'")
In [ ]:
def load_model(model_path, vocab_size):
    checkpoint = torch.load(model_path, map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu"))
    model = Task1(vocab_size)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()
    return model, checkpoint["sequence_length"], checkpoint["pitch_range"]
In [ ]:
def generate_music(model, start_seq, length=200, temperature=1.0, seed=None):
    if seed is not None:
        torch.manual_seed(seed)

    model.eval()
    device = next(model.parameters()).device
    input_seq = torch.tensor(start_seq, dtype=torch.long).unsqueeze(0).to(device)
    generated = start_seq[:]

    with torch.no_grad():
        for _ in range(length):
            output = model(input_seq)[0]

            if output.dim() == 3:
                logits = output[0, -1, :]
            elif output.dim() == 2:
                logits = output[-1]      
            else:
                raise ValueError("Unexpected model output shape.")

            logits = logits / temperature
            probs = torch.softmax(logits, dim=0)
            next_note = torch.multinomial(probs, num_samples=1).item()
            generated.append(next_note)

            input_seq = torch.tensor(generated[-len(start_seq):], dtype=torch.long).unsqueeze(0).to(device)

    return generated
In [ ]:
def save_midi(note_sequence, file_path, preprocessor, tempo=120):
    pm = pretty_midi.PrettyMIDI()
    instrument = pretty_midi.Instrument(program=0)
    start_time = 0.0
    step = 60.0 / tempo

    for pitch in note_sequence:
        note = pretty_midi.Note(
            velocity=100,
            pitch=pitch + preprocessor.pitch_range[0],
            start=start_time,
            end=start_time + step
        )
        instrument.notes.append(note)
        start_time += step

    pm.instruments.append(instrument)
    pm.write(file_path)
    print(f"MIDI saved: {file_path}")
In [ ]:
def play_midi(file_path):
    midi_data = pretty_midi.PrettyMIDI(file_path)
    audio_data = midi_data.synthesize()
    display(Audio(audio_data, rate=44100))
In [ ]:
from torch.utils.data import DataLoader, random_split

preprocessor = MIDIPreprocessor(sequence_length=50)
dataset = MIDIDataset("data/midis", preprocessor, file_limit=100)

val_size = int(0.1 * len(dataset))
train_size = len(dataset) - val_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)
Preprocessor initialized: pitch range (21, 109), sequence length 50
Loading MIDI dataset from: data/midis
Building note sequences...
Built 336358 sequences
Dataset ready: 336358 sequences
In [ ]:
model = Task1(vocab_size=preprocessor.num_pitches)

train_model(model, train_loader, val_loader, preprocessor, epochs=20)
Starting training on cuda...

Epoch 1/20
   Batch 0/4731, Train Loss: 4.4866
   Batch 500/4731, Train Loss: 3.1311
   Batch 1000/4731, Train Loss: 2.8842
   Batch 1500/4731, Train Loss: 2.8125
   Batch 2000/4731, Train Loss: 2.6422
   Batch 2500/4731, Train Loss: 2.4669
   Batch 3000/4731, Train Loss: 2.3412
   Batch 3500/4731, Train Loss: 2.0884
   Batch 4000/4731, Train Loss: 2.0008
   Batch 4500/4731, Train Loss: 1.9386
Epoch 1 complete. Train Loss: 2.5786, Val Loss: 2.0064
Model saved to symbolic_model_f100_e1.pth

Epoch 2/20
   Batch 0/4731, Train Loss: 1.9593
   Batch 500/4731, Train Loss: 1.8270
   Batch 1000/4731, Train Loss: 1.7251
   Batch 1500/4731, Train Loss: 1.7805
   Batch 2000/4731, Train Loss: 1.6123
   Batch 2500/4731, Train Loss: 1.6750
   Batch 3000/4731, Train Loss: 1.6325
   Batch 3500/4731, Train Loss: 1.5674
   Batch 4000/4731, Train Loss: 1.5423
   Batch 4500/4731, Train Loss: 1.5208
Epoch 2 complete. Train Loss: 1.7081, Val Loss: 1.5302
Model saved to symbolic_model_f100_e2.pth

Epoch 3/20
   Batch 0/4731, Train Loss: 1.4920
   Batch 500/4731, Train Loss: 1.4557
   Batch 1000/4731, Train Loss: 1.4258
   Batch 1500/4731, Train Loss: 1.4187
   Batch 2000/4731, Train Loss: 1.3488
   Batch 2500/4731, Train Loss: 1.3993
   Batch 3000/4731, Train Loss: 1.3231
   Batch 3500/4731, Train Loss: 1.4511
   Batch 4000/4731, Train Loss: 1.3839
   Batch 4500/4731, Train Loss: 1.2585
Epoch 3 complete. Train Loss: 1.3851, Val Loss: 1.3328
Model saved to symbolic_model_f100_e3.pth

Epoch 4/20
   Batch 0/4731, Train Loss: 1.3088
   Batch 500/4731, Train Loss: 1.3380
   Batch 1000/4731, Train Loss: 1.3364
   Batch 1500/4731, Train Loss: 1.2599
   Batch 2000/4731, Train Loss: 1.2737
   Batch 2500/4731, Train Loss: 1.2240
   Batch 3000/4731, Train Loss: 1.2105
   Batch 3500/4731, Train Loss: 1.1997
   Batch 4000/4731, Train Loss: 1.2086
   Batch 4500/4731, Train Loss: 1.1723
Epoch 4 complete. Train Loss: 1.2356, Val Loss: 1.2218
Model saved to symbolic_model_f100_e4.pth

Epoch 5/20
   Batch 0/4731, Train Loss: 1.1673
   Batch 500/4731, Train Loss: 1.2181
   Batch 1000/4731, Train Loss: 1.1913
   Batch 1500/4731, Train Loss: 1.1767
   Batch 2000/4731, Train Loss: 1.2371
   Batch 2500/4731, Train Loss: 1.2301
   Batch 3000/4731, Train Loss: 1.0791
   Batch 3500/4731, Train Loss: 1.1235
   Batch 4000/4731, Train Loss: 1.1507
   Batch 4500/4731, Train Loss: 1.0776
Epoch 5 complete. Train Loss: 1.1516, Val Loss: 1.1596
Model saved to symbolic_model_f100_e5.pth

Epoch 6/20
   Batch 0/4731, Train Loss: 1.0774
   Batch 500/4731, Train Loss: 1.1313
   Batch 1000/4731, Train Loss: 1.1265
   Batch 1500/4731, Train Loss: 1.1412
   Batch 2000/4731, Train Loss: 1.0861
   Batch 2500/4731, Train Loss: 1.0883
   Batch 3000/4731, Train Loss: 1.1273
   Batch 3500/4731, Train Loss: 1.0519
   Batch 4000/4731, Train Loss: 1.0771
   Batch 4500/4731, Train Loss: 1.1455
Epoch 6 complete. Train Loss: 1.0974, Val Loss: 1.1170
Model saved to symbolic_model_f100_e6.pth

Epoch 7/20
   Batch 0/4731, Train Loss: 1.0939
   Batch 500/4731, Train Loss: 1.1123
   Batch 1000/4731, Train Loss: 1.0623
   Batch 1500/4731, Train Loss: 1.0700
   Batch 2000/4731, Train Loss: 1.0477
   Batch 2500/4731, Train Loss: 1.1031
   Batch 3000/4731, Train Loss: 0.9900
   Batch 3500/4731, Train Loss: 1.0420
   Batch 4000/4731, Train Loss: 1.0634
   Batch 4500/4731, Train Loss: 1.0290
Epoch 7 complete. Train Loss: 1.0581, Val Loss: 1.0869
Model saved to symbolic_model_f100_e7.pth

Epoch 8/20
   Batch 0/4731, Train Loss: 1.0379
   Batch 500/4731, Train Loss: 1.0057
   Batch 1000/4731, Train Loss: 1.0283
   Batch 1500/4731, Train Loss: 0.9872
   Batch 2000/4731, Train Loss: 1.0183
   Batch 2500/4731, Train Loss: 1.0183
   Batch 3000/4731, Train Loss: 1.0018
   Batch 3500/4731, Train Loss: 0.9894
   Batch 4000/4731, Train Loss: 1.0257
   Batch 4500/4731, Train Loss: 0.9834
Epoch 8 complete. Train Loss: 1.0287, Val Loss: 1.0570
Model saved to symbolic_model_f100_e8.pth

Epoch 9/20
   Batch 0/4731, Train Loss: 0.9979
   Batch 500/4731, Train Loss: 0.9744
   Batch 1000/4731, Train Loss: 0.9818
   Batch 1500/4731, Train Loss: 1.0085
   Batch 2000/4731, Train Loss: 1.0561
   Batch 2500/4731, Train Loss: 0.9815
   Batch 3000/4731, Train Loss: 1.0033
   Batch 3500/4731, Train Loss: 0.9868
   Batch 4000/4731, Train Loss: 0.9614
   Batch 4500/4731, Train Loss: 1.0136
Epoch 9 complete. Train Loss: 1.0043, Val Loss: 1.0344
Model saved to symbolic_model_f100_e9.pth

Epoch 10/20
   Batch 0/4731, Train Loss: 0.9829
   Batch 500/4731, Train Loss: 1.0075
   Batch 1000/4731, Train Loss: 0.9666
   Batch 1500/4731, Train Loss: 1.0069
   Batch 2000/4731, Train Loss: 0.9614
   Batch 2500/4731, Train Loss: 0.9906
   Batch 3000/4731, Train Loss: 0.9751
   Batch 3500/4731, Train Loss: 1.0184
   Batch 4000/4731, Train Loss: 0.9644
   Batch 4500/4731, Train Loss: 1.0216
Epoch 10 complete. Train Loss: 0.9842, Val Loss: 1.0179
Model saved to symbolic_model_f100_e10.pth

Epoch 11/20
   Batch 0/4731, Train Loss: 0.9614
   Batch 500/4731, Train Loss: 0.9420
   Batch 1000/4731, Train Loss: 0.9650
   Batch 1500/4731, Train Loss: 0.9655
   Batch 2000/4731, Train Loss: 0.9812
   Batch 2500/4731, Train Loss: 0.9396
   Batch 3000/4731, Train Loss: 0.9947
   Batch 3500/4731, Train Loss: 1.0044
   Batch 4000/4731, Train Loss: 0.9412
   Batch 4500/4731, Train Loss: 0.9848
Epoch 11 complete. Train Loss: 0.9671, Val Loss: 1.0031
Model saved to symbolic_model_f100_e11.pth

Epoch 12/20
   Batch 0/4731, Train Loss: 0.9281
   Batch 500/4731, Train Loss: 0.9982
   Batch 1000/4731, Train Loss: 0.9070
   Batch 1500/4731, Train Loss: 1.0118
   Batch 2000/4731, Train Loss: 0.9661
   Batch 2500/4731, Train Loss: 0.9760
   Batch 3000/4731, Train Loss: 0.9840
   Batch 3500/4731, Train Loss: 0.9229
   Batch 4000/4731, Train Loss: 0.9543
   Batch 4500/4731, Train Loss: 0.9869
Epoch 12 complete. Train Loss: 0.9520, Val Loss: 0.9873
Model saved to symbolic_model_f100_e12.pth

Epoch 13/20
   Batch 0/4731, Train Loss: 0.9392
   Batch 500/4731, Train Loss: 0.8681
   Batch 1000/4731, Train Loss: 0.9210
   Batch 1500/4731, Train Loss: 0.9989
   Batch 2000/4731, Train Loss: 1.0534
   Batch 2500/4731, Train Loss: 0.8954
   Batch 3000/4731, Train Loss: 0.9963
   Batch 3500/4731, Train Loss: 0.9660
   Batch 4000/4731, Train Loss: 0.9737
   Batch 4500/4731, Train Loss: 0.9184
Epoch 13 complete. Train Loss: 0.9386, Val Loss: 0.9761
Model saved to symbolic_model_f100_e13.pth

Epoch 14/20
   Batch 0/4731, Train Loss: 0.8807
   Batch 500/4731, Train Loss: 0.9447
   Batch 1000/4731, Train Loss: 0.9628
   Batch 1500/4731, Train Loss: 0.9209
   Batch 2000/4731, Train Loss: 0.9491
   Batch 2500/4731, Train Loss: 0.9531
   Batch 3000/4731, Train Loss: 0.9424
   Batch 3500/4731, Train Loss: 0.9481
   Batch 4000/4731, Train Loss: 0.9102
   Batch 4500/4731, Train Loss: 0.9232
Epoch 14 complete. Train Loss: 0.9268, Val Loss: 0.9665
Model saved to symbolic_model_f100_e14.pth

Epoch 15/20
   Batch 0/4731, Train Loss: 0.8969
   Batch 500/4731, Train Loss: 0.9352
   Batch 1000/4731, Train Loss: 0.8824
   Batch 1500/4731, Train Loss: 0.8754
   Batch 2000/4731, Train Loss: 0.8956
   Batch 2500/4731, Train Loss: 0.9022
   Batch 3000/4731, Train Loss: 0.8872
   Batch 3500/4731, Train Loss: 0.9434
   Batch 4000/4731, Train Loss: 0.8789
   Batch 4500/4731, Train Loss: 0.8746
Epoch 15 complete. Train Loss: 0.9161, Val Loss: 0.9536
Model saved to symbolic_model_f100_e15.pth

Epoch 16/20
   Batch 0/4731, Train Loss: 0.8407
   Batch 500/4731, Train Loss: 0.9048
   Batch 1000/4731, Train Loss: 0.9198
   Batch 1500/4731, Train Loss: 0.8823
   Batch 2000/4731, Train Loss: 0.9031
   Batch 2500/4731, Train Loss: 0.9427
   Batch 3000/4731, Train Loss: 0.9086
   Batch 3500/4731, Train Loss: 0.8719
   Batch 4000/4731, Train Loss: 0.8657
   Batch 4500/4731, Train Loss: 0.9103
Epoch 16 complete. Train Loss: 0.9061, Val Loss: 0.9461
Model saved to symbolic_model_f100_e16.pth

Epoch 17/20
   Batch 0/4731, Train Loss: 0.8703
   Batch 500/4731, Train Loss: 0.9346
   Batch 1000/4731, Train Loss: 0.8435
   Batch 1500/4731, Train Loss: 0.8955
   Batch 2000/4731, Train Loss: 0.8907
   Batch 2500/4731, Train Loss: 0.8927
   Batch 3000/4731, Train Loss: 0.9145
   Batch 3500/4731, Train Loss: 0.9361
   Batch 4000/4731, Train Loss: 0.8999
   Batch 4500/4731, Train Loss: 0.9123
Epoch 17 complete. Train Loss: 0.8966, Val Loss: 0.9379
Model saved to symbolic_model_f100_e17.pth

Epoch 18/20
   Batch 0/4731, Train Loss: 0.8617
   Batch 500/4731, Train Loss: 0.8669
   Batch 1000/4731, Train Loss: 0.8392
   Batch 1500/4731, Train Loss: 0.8924
   Batch 2000/4731, Train Loss: 0.9068
   Batch 2500/4731, Train Loss: 0.8603
   Batch 3000/4731, Train Loss: 0.8543
   Batch 3500/4731, Train Loss: 0.8754
   Batch 4000/4731, Train Loss: 0.8908
   Batch 4500/4731, Train Loss: 0.8644
Epoch 18 complete. Train Loss: 0.8885, Val Loss: 0.9289
Model saved to symbolic_model_f100_e18.pth

Epoch 19/20
   Batch 0/4731, Train Loss: 0.8545
   Batch 500/4731, Train Loss: 0.8145
   Batch 1000/4731, Train Loss: 0.9053
   Batch 1500/4731, Train Loss: 0.9408
   Batch 2000/4731, Train Loss: 0.8843
   Batch 2500/4731, Train Loss: 0.8901
   Batch 3000/4731, Train Loss: 0.9212
   Batch 3500/4731, Train Loss: 0.9347
   Batch 4000/4731, Train Loss: 0.8735
   Batch 4500/4731, Train Loss: 0.8816
Epoch 19 complete. Train Loss: 0.8801, Val Loss: 0.9244
Model saved to symbolic_model_f100_e19.pth

Epoch 20/20
   Batch 0/4731, Train Loss: 0.9042
   Batch 500/4731, Train Loss: 0.8567
   Batch 1000/4731, Train Loss: 0.8567
   Batch 1500/4731, Train Loss: 0.8742
   Batch 2000/4731, Train Loss: 0.9002
   Batch 2500/4731, Train Loss: 0.9080
   Batch 3000/4731, Train Loss: 0.8216
   Batch 3500/4731, Train Loss: 0.8231
   Batch 4000/4731, Train Loss: 0.8374
   Batch 4500/4731, Train Loss: 0.8503
Epoch 20 complete. Train Loss: 0.8727, Val Loss: 0.9160
Model saved to symbolic_model_f100_e20.pth
Loss plot saved as 'loss_plot.png'
In [ ]:
def midi_generation(model_path, output_path="generated_random.mid", length=200, temperature=1.5):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    preprocessor = MIDIPreprocessor(sequence_length=50)
    vocab_size = preprocessor.num_pitches

    checkpoint = torch.load(model_path, map_location=device)
    model = Task1(vocab_size)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.to(device)
    model.eval()

    start_seq = [random.randint(0, vocab_size - 1) for _ in range(preprocessor.sequence_length)]

    generated_notes = generate_music(model, start_seq, length=length, temperature=temperature, seed=None)

    save_midi(generated_notes, output_path, preprocessor)

midi_generation("symbolic_model_f100_e20.pth", output_path="trial4.mid", length=300, temperature=1.0)
Preprocessor initialized: pitch range (21, 109), sequence length 50
MIDI saved: trial4.mid

Modeling + Evaluation: Task2: Code Walkthrough (Symbolic, conditional generation)¶

Baseline Model: Task 2 (Symbolic, conditional generation)¶

In [ ]:
# !pip install music21
In [ ]:
from glob import glob
import random
import numpy as np
import pandas as pd
from numpy.random import choice
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from miditok import REMI, TokenizerConfig
from mido import Message, MidiFile, MidiTrack, MetaMessage, bpm2tempo
from music21 import midi, chord, note
In [ ]:
device = torch.device("cuda")
In [ ]:
midi_files = glob('data/midis/*.mid')
len(midi_files)
Out[ ]:
10854
In [ ]:
# Set the random seed
random.seed(42)

dataroot = '/data/midis'
sample_files = random.sample(midi_files, 1500)
print(sample_files[0])
data/midis/Ladurner, Ignace Antoine, 3 Keyboard Sonatas, Op.4, fzYLT5JMtZk.mid
In [ ]:
def midi_preprocess(midi_file, max_len = 30):
    midi = MidiFile(midi_file)
    note_times = {}
    melodies = []
    harmonies = []

    end_time = 30
    current_time = 0

    for msg in midi.play():
        current_time += msg.time
        if current_time >= end_time:
            break
        if msg.type == 'note_on' and msg.velocity > 0:
            timestamp = msg.time
            if timestamp not in note_times:
                note_times[timestamp] = []
            note_times[timestamp].append(msg.note)

    for timestamp, notes in note_times.items():
        if len(melodies) >= max_len:
            break
        melodies.append(notes[0])
        harmonies.append(chord.Chord(notes).commonName)

    if len(melodies) < max_len:
        melodies += [0] * (max_len - len(melodies)) 
        harmonies += ["Rest"] * (max_len - len(harmonies))

    return melodies, harmonies
In [ ]:
melodies = []
harmonies = []

for i in range(10):
    print(i)
    midi_file = sample_files[i]
    melody, harmony = midi_preprocess(midi_file)
    melodies.append(melody)
    harmonies.append(harmony)
0
1
2
3
4
5
6
7
8
9
In [ ]:
note_set = sorted(set([note for melody in melodies for note in melody]))
chord_set = sorted(set([chord for harmony in harmonies for chord in harmony]))

note_to_int = {note: i for i, note in enumerate(note_set)}
chord_to_int = {chord: i for i, chord in enumerate(chord_set)}

X_train = [[note_to_int[n] for n in melody] for melody in melodies]
y_train = [[chord_to_int[c] for c in harmony] for harmony in harmonies]

X_train = np.array(X_train)
y_train = np.array(y_train)
In [ ]:
print(X_train)
print(y_train)
[[34 32 15 31 29 19 27 27 22 19 15 29 22 31 32 34 36 22 39 15 20 20 15 38
  36 24 20 36 24 34]
 [29  7 20 12 24 25 17 30 24 32 30 25 34 24 24 25 24 25 36  8 13 25 27 34
  27 30 12 24 25 19]
 [39  6 12 19  6  5 12 38 15 12 19 41 22 12 31  7 21 33 38 14 33 36 34 26
  34  5 19 24  2 26]
 [20 20 20 20 20 20 32 32 32 32 32 32 22 35  4 15 27 20 30 24 18 21 20 32
   1 20 23 18 10 18]
 [ 3  8 11 22 13 16 18 16 18 22 25 28 30 18 34 37 34 40 32 16  8 13 30 43
  25 34 18 25 42 37]
 [32 34 18 24 27 20  9 18 27 32 34 17 25 29 29 25 32 29 37  9  9 15 32 36
  37 41 12 39 36 32]
 [24  7 29 22  6 24  7  6 36 24 31 38 32 36 31 32 26 27 34 22 29 36 32 31
   5 29 40 28 26 38]
 [15 18 22 30 33 23 15 20 17 23 43 38 39 41 39 22 42 34 17 35 39 32 30 18
  15  6 34 27  7 22]
 [32 20 32 20 32 20 32 32 20  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
   0  0  0  0  0  0]
 [20 22 24 25 27 25 24 29 22 27 29 27 25 24 27 24 22 20 29 22 20 19 32 17
  34 32 36 29 34 26]]
[[20 64 48 64 17 64 18  4 45 64 64 21  2 62 11 28 36 70 63 65 74 64 37 46
  45 64 64 64 19 66]
 [64 64 50 70 64 64 48 74 72 64 25 56 64  9 64 64 69 69  7 64 68 64 64 64
  44 64 64 64 64 64]
 [64 30  1 75 64 50  1 64 64 31 67 64 64 69 64 64 64 64 64 64 64 64 64 64
  64 64 64 64 64 64]
 [64 59 30 18 13 22 64 64 64 64 64 64 64 64 64 64 64 64 64 64 64  0 64 64
   8 64 13 64  7 64]
 [64 64  0 64 42 64 64 25 72 61 64  3 17 44 46 64 61 50 64 64 64 64 64 64
  50 64 64 64 64 46]
 [64 64 64 38 46 49 64 64 34 64 64 64 16 58 55 50 64 68  0  2 64 64  3 64
  64 64 74 64 64 15]
 [64 35 43 64 41 32 62  1 64 39 50 64 64 64 34 26 64 64 64 17 64 50 29 64
  64 52 64  4 64 64]
 [64 40 57  6  9 30 27 24 47  1 64 53  1 64 64 57 70  9 60  9 50 13 64  5
  24 50 13 18 10 64]
 [64 64 64 20 64 64 20 64 64 23 23 23 23 23 23 23 23 23 23 23 23 23 23 23
  23 23 23 23 23 23]
 [64 44 64 74 64 73 64 48 64 33 64 14 64 76  6 54 64 64 38 64 17 64 12 70
  64  6 64 64 51 71]]
In [ ]:
x_df = pd.DataFrame(X_train)
y_df = pd.DataFrame(y_train)

x_df.to_csv("training_x_test.csv", index=False)
y_df.to_csv("training_y_test.csv", index=False)

print("Single CSV file saved successfully!")
Single CSV file saved successfully!
In [ ]:
df_x = pd.read_csv("training_x_test.csv")
df_y = pd.read_csv("training_y_test.csv")

X_train = df_x.values
y_train = df_y.values

print("X_train shape:", X_train.shape)
print("y_train shape:", y_train.shape)
X_train shape: (10, 30)
y_train shape: (10, 30)
In [ ]:
class LSTM(nn.Module):
    def __init__(self, input_size, sequence_length, num_classes):
        super(LSTM, self).__init__()
        self.embedding = nn.Embedding(input_size, 64)
        self.lstm = nn.LSTM(64, 128, batch_first=True)
        self.fc1 = nn.Linear(128, 64)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(64, num_classes)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        x = self.embedding(x)
        x, _ = self.lstm(x)
        x = x[:, -1, :]
        
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return self.softmax(x)
In [ ]:
model = LSTM(input_size=len(note_set), sequence_length=X_train.shape[1], num_classes=len(chord_set))
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
In [ ]:
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense, Embedding

model = Sequential([
    Embedding(len(note_set), 64, input_length=X_train.shape[1]),
    LSTM(128, return_sequences=True),
    Dense(64, activation='relu'),
    Dense(len(chord_set), activation='softmax')
])

model.compile(loss='sparse_categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
losses = model.fit(X_train, y_train, epochs=50, batch_size=32)
2025-06-03 18:11:00.382468: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-06-03 18:11:00.382527: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-06-03 18:11:00.383859: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-06-03 18:11:00.390558: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2025-06-03 18:11:02.252067: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:901] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2025-06-03 18:11:02.254867: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:901] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2025-06-03 18:11:02.255419: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:901] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2025-06-03 18:11:02.262151: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:901] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2025-06-03 18:11:02.262703: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:901] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2025-06-03 18:11:02.263233: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:901] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2025-06-03 18:11:02.384097: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:901] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2025-06-03 18:11:02.384650: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:901] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2025-06-03 18:11:02.385168: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:901] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2025-06-03 18:11:02.385634: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1929] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 10534 MB memory:  -> device: 0, name: NVIDIA GeForce GTX 1080 Ti, pci bus id: 0000:c4:00.0, compute capability: 6.1
Epoch 1/50
2025-06-03 18:11:04.457748: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:454] Loaded cuDNN version 8902
2025-06-03 18:11:04.686499: I external/local_xla/xla/service/service.cc:168] XLA service 0x7f6ff8241f30 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2025-06-03 18:11:04.686559: I external/local_xla/xla/service/service.cc:176]   StreamExecutor device (0): NVIDIA GeForce GTX 1080 Ti, Compute Capability 6.1
2025-06-03 18:11:04.698810: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:269] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1748974264.826799     689 device_compiler.h:186] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.
1/1 [==============================] - 3s 3s/step - loss: 4.3424 - accuracy: 0.0067
Epoch 2/50
1/1 [==============================] - 0s 8ms/step - loss: 4.3338 - accuracy: 0.3900
Epoch 3/50
1/1 [==============================] - 0s 7ms/step - loss: 4.3238 - accuracy: 0.5300
Epoch 4/50
1/1 [==============================] - 0s 9ms/step - loss: 4.3119 - accuracy: 0.5300
Epoch 5/50
1/1 [==============================] - 0s 8ms/step - loss: 4.2974 - accuracy: 0.5300
Epoch 6/50
1/1 [==============================] - 0s 9ms/step - loss: 4.2792 - accuracy: 0.5267
Epoch 7/50
1/1 [==============================] - 0s 9ms/step - loss: 4.2554 - accuracy: 0.5167
Epoch 8/50
1/1 [==============================] - 0s 9ms/step - loss: 4.2231 - accuracy: 0.4800
Epoch 9/50
1/1 [==============================] - 0s 9ms/step - loss: 4.1776 - accuracy: 0.4800
Epoch 10/50
1/1 [==============================] - 0s 8ms/step - loss: 4.1101 - accuracy: 0.4800
Epoch 11/50
1/1 [==============================] - 0s 10ms/step - loss: 4.0057 - accuracy: 0.4800
Epoch 12/50
1/1 [==============================] - 0s 9ms/step - loss: 3.8416 - accuracy: 0.4800
Epoch 13/50
1/1 [==============================] - 0s 8ms/step - loss: 3.6121 - accuracy: 0.4800
Epoch 14/50
1/1 [==============================] - 0s 8ms/step - loss: 3.4147 - accuracy: 0.4800
Epoch 15/50
1/1 [==============================] - 0s 8ms/step - loss: 3.3813 - accuracy: 0.4800
Epoch 16/50
1/1 [==============================] - 0s 8ms/step - loss: 3.4057 - accuracy: 0.4800
Epoch 17/50
1/1 [==============================] - 0s 8ms/step - loss: 3.3897 - accuracy: 0.4800
Epoch 18/50
1/1 [==============================] - 0s 9ms/step - loss: 3.3321 - accuracy: 0.4800
Epoch 19/50
1/1 [==============================] - 0s 8ms/step - loss: 3.2537 - accuracy: 0.4800
Epoch 20/50
1/1 [==============================] - 0s 8ms/step - loss: 3.1750 - accuracy: 0.4800
Epoch 21/50
1/1 [==============================] - 0s 9ms/step - loss: 3.1129 - accuracy: 0.4800
Epoch 22/50
1/1 [==============================] - 0s 9ms/step - loss: 3.0763 - accuracy: 0.4800
Epoch 23/50
1/1 [==============================] - 0s 8ms/step - loss: 3.0618 - accuracy: 0.4800
Epoch 24/50
1/1 [==============================] - 0s 8ms/step - loss: 3.0584 - accuracy: 0.4800
Epoch 25/50
1/1 [==============================] - 0s 9ms/step - loss: 3.0555 - accuracy: 0.4800
Epoch 26/50
1/1 [==============================] - 0s 8ms/step - loss: 3.0461 - accuracy: 0.4800
Epoch 27/50
1/1 [==============================] - 0s 9ms/step - loss: 3.0293 - accuracy: 0.4800
Epoch 28/50
1/1 [==============================] - 0s 9ms/step - loss: 3.0069 - accuracy: 0.4800
Epoch 29/50
1/1 [==============================] - 0s 8ms/step - loss: 2.9822 - accuracy: 0.4800
Epoch 30/50
1/1 [==============================] - 0s 8ms/step - loss: 2.9582 - accuracy: 0.4800
Epoch 31/50
1/1 [==============================] - 0s 9ms/step - loss: 2.9392 - accuracy: 0.4800
Epoch 32/50
1/1 [==============================] - 0s 9ms/step - loss: 2.9275 - accuracy: 0.4800
Epoch 33/50
1/1 [==============================] - 0s 8ms/step - loss: 2.9217 - accuracy: 0.4800
Epoch 34/50
1/1 [==============================] - 0s 8ms/step - loss: 2.9180 - accuracy: 0.4800
Epoch 35/50
1/1 [==============================] - 0s 8ms/step - loss: 2.9137 - accuracy: 0.4800
Epoch 36/50
1/1 [==============================] - 0s 8ms/step - loss: 2.9069 - accuracy: 0.4800
Epoch 37/50
1/1 [==============================] - 0s 8ms/step - loss: 2.8970 - accuracy: 0.4800
Epoch 38/50
1/1 [==============================] - 0s 9ms/step - loss: 2.8853 - accuracy: 0.4800
Epoch 39/50
1/1 [==============================] - 0s 9ms/step - loss: 2.8733 - accuracy: 0.4800
Epoch 40/50
1/1 [==============================] - 0s 9ms/step - loss: 2.8625 - accuracy: 0.4800
Epoch 41/50
1/1 [==============================] - 0s 8ms/step - loss: 2.8539 - accuracy: 0.4800
Epoch 42/50
1/1 [==============================] - 0s 8ms/step - loss: 2.8474 - accuracy: 0.4800
Epoch 43/50
1/1 [==============================] - 0s 8ms/step - loss: 2.8416 - accuracy: 0.4800
Epoch 44/50
1/1 [==============================] - 0s 8ms/step - loss: 2.8353 - accuracy: 0.4800
Epoch 45/50
1/1 [==============================] - 0s 8ms/step - loss: 2.8275 - accuracy: 0.4800
Epoch 46/50
1/1 [==============================] - 0s 8ms/step - loss: 2.8180 - accuracy: 0.4800
Epoch 47/50
1/1 [==============================] - 0s 8ms/step - loss: 2.8073 - accuracy: 0.4800
Epoch 48/50
1/1 [==============================] - 0s 9ms/step - loss: 2.7963 - accuracy: 0.4800
Epoch 49/50
1/1 [==============================] - 0s 8ms/step - loss: 2.7859 - accuracy: 0.4800
Epoch 50/50
1/1 [==============================] - 0s 8ms/step - loss: 2.7757 - accuracy: 0.4800
In [ ]:
loss = losses.history['loss']

plt.plot(range(50), loss)
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.title('Training Loss vs. Epochs')
plt.show()
In [ ]:
def get_melody(midi_file):
    midi = MidiFile(midi_file)
    melodies = []

    end_time = 30
    current_time = 0

    for msg in midi.play():
        current_time += msg.time
        if current_time >= end_time:
            break
        if msg.type == 'note_on' and msg.velocity > 0:
            melodies.append(msg.note)

    return melodies
In [ ]:
# Test file
midi_file = sample_files[1001]
melody = get_melody(midi_file)[:30]

print(midi_file)
print(melody)
data/midis/Joplin, Scott, A Breeze from Alabama, nXe43xnOEf4.mid
[60, 55, 64, 67, 48, 60, 72, 63, 60, 57, 54, 57, 69, 68, 56, 69, 57, 71, 59, 72, 55, 60, 64, 55, 67, 52, 64, 63, 51, 64]
In [ ]:
def generate_harmony(melody):
    nums = [note_to_int[n] if n in note_to_int.keys() else 0 for n in melody]
    nums = np.array([nums])

    prediction = model.predict(nums)
    harmony_predicted = [chord_set[np.argmax(p)] for p in prediction[0]]

    return harmony_predicted

predicted_harmony = generate_harmony(melody)
print(predicted_harmony)
1/1 [==============================] - 0s 381ms/step
['note', 'note', 'note', 'note', 'note', 'note', 'note', 'note', 'note', 'note', 'note', 'note', 'note', 'note', 'note', 'note', 'note', 'note', 'note', 'note', 'note', 'note', 'note', 'note', 'note', 'note', 'note', 'note', 'note', 'note']
In [ ]:
def harmony_midi(harmony):
    midi_notes = []
    for h in harmony:
        try:
            chord_obj = chord.Chord(h)
            midi_notes.append(chord_obj.pitches[0].midi)
        except:
            try:
                midi_notes.append(note.Note(h).pitch.midi)
            except:
                midi_notes.append(60)

    return midi_notes
In [ ]:
midi_harmony = harmony_midi(predicted_harmony)
print(midi_harmony)
[60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60]
In [ ]:
def generate_midi(melody, harmony, output_file="generated_harmony.mid"):
    midi_new = MidiFile()
    melody_track = MidiTrack()
    harmony_track = MidiTrack()

    for m in melody:
        melody_track.append(Message('note_on', note=m, velocity=64, time=200))
        melody_track.append(Message('note_off', note=m, velocity=64, time=200))

    for h in harmony:
        harmony_track.append(Message('note_on', note=h, velocity=64, time=200))
        harmony_track.append(Message('note_off', note=h, velocity=64, time=200))

    midi_new.tracks.append(melody_track)
    midi_new.tracks.append(harmony_track)

    midi_new.save(output_file)
    print(f"Saved generated harmony to {output_file}")
In [ ]:
generate_midi(melody, midi_harmony)
Saved generated harmony to generated_harmony.mid
In [ ]:
def play_midi(path):
    mf = midi.MidiFile()
    mf.open(path)
    mf.read()
    mf.close()
    s = midi.translate.midiFileToStream(mf)
    s.show('midi')

play_midi("generated_harmony.mid")

Task2: Main Model: VAE¶

In [ ]:
import os
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
from music21 import converter, note, chord, stream, roman
In [ ]:
class Config:
    max_seq_len = 512
    latent_dim = 256
    d_model = 384
    n_heads = 6
    n_layers = 4
    dropout = 0.1
    batch_size = 4
    lr = 1e-4
    grad_clip = 1.0
    kl_anneal_epochs = 20
    eps = 1e-8
    temp = 1.2
    top_k = 20
In [ ]:
def build_vocab():
    vocab = {"PAD": 0, "BOS": 1}
    idx = 2
    for p in range(21, 109):
        for d in [0.25, 0.5, 1.0, 2.0]:
            vocab[f"Note_{p}_{d}"] = idx
            idx += 1
    for fig in ["I", "ii", "iii", "IV", "V", "vi", "vii", "V7", "ii7", "I6", "V/V"]:
        vocab[f"Roman_{fig}"] = idx
        idx += 1
    return vocab

vocab = build_vocab()
inv_vocab = {v: k for k, v in vocab.items()}

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
In [ ]:
class Music21ConditionalDataset(Dataset):
    def __init__(self, midi_dir, max_seq_len=512):
        self.paths = list(Path(midi_dir).glob("*.mid"))[:10]
        self.max_seq_len = max_seq_len
        self.pad_id = vocab["PAD"]
        self.bos_id = vocab["BOS"]

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        try:
            score = converter.parse(str(self.paths[idx]))
            melody, harmony = self.extract_tokens(score)
            melody = [self.bos_id] + melody[:self.max_seq_len - 1]
            harmony = [self.bos_id] + harmony[:self.max_seq_len - 1]

            melody += [self.pad_id] * (self.max_seq_len - len(melody))
            harmony += [self.pad_id] * (self.max_seq_len - len(harmony))
            return torch.LongTensor(melody), torch.LongTensor(harmony)
        except:
            dummy = [self.bos_id] + [self.pad_id] * (self.max_seq_len - 1)
            return torch.LongTensor(dummy), torch.LongTensor(dummy)

    def extract_tokens(self, score):
        melody_tokens, chord_tokens = [], []

        melody_part = score.parts[0]  # Assume melody is in first part
        key = score.analyze('key')

        for el in melody_part.recurse().notesAndRests:
            dur = round(el.duration.quarterLength, 2)
            dur = min([0.25, 0.5, 1.0, 2.0], key=lambda x: abs(x - dur))
            if isinstance(el, note.Note):
                pitch = el.pitch.midi
                tok = f"Note_{pitch}_{dur}"
                melody_tokens.append(vocab.get(tok, self.pad_id))
            else:
                melody_tokens.append(self.pad_id)

        harmony = score.chordify()
        for el in harmony.flat.getElementsByClass(chord.Chord):
            try:
                rn = roman.romanNumeralFromChord(el, key)
                tok = f"Roman_{rn.figure}"
                chord_tokens.append(vocab.get(tok, self.pad_id))
            except:
                chord_tokens.append(self.pad_id)

        return melody_tokens, chord_tokens
In [ ]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=512):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len).unsqueeze(1)
        div = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(pos * div)
        pe[:, 1::2] = torch.cos(pos * div)
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]
In [ ]:
class MusicVAE(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, Config.d_model)
        nn.init.normal_(self.embed.weight, mean=0.0, std=0.02)
        self.pos_enc = PositionalEncoding(Config.d_model)

        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(Config.d_model, Config.n_heads, Config.d_model * 2, Config.dropout, batch_first=True, activation='gelu'),
            num_layers=Config.n_layers
        )
        self._init_weights(self.encoder)
        self.fc_mu = nn.Linear(Config.d_model, Config.latent_dim)
        self.fc_logvar = nn.Linear(Config.d_model, Config.latent_dim)
        nn.init.xavier_normal_(self.fc_mu.weight)
        nn.init.xavier_normal_(self.fc_logvar.weight)
        self.latent_proj = nn.Linear(Config.latent_dim, Config.d_model)

        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(Config.d_model, Config.n_heads, Config.d_model * 2, Config.dropout, batch_first=True, activation='gelu'),
            num_layers=Config.n_layers
        )
        self._init_weights(self.decoder)
        self.fc_out = nn.Linear(Config.d_model, vocab_size)
        nn.init.zeros_(self.fc_out.bias)

    def _init_weights(self, module):
        for p in module.parameters():
            if p.dim() > 1:
                nn.init.xavier_normal_(p)

    def encode(self, melody):
        src = self.embed(melody) * math.sqrt(Config.d_model)
        src = self.pos_enc(src)
        memory = self.encoder(src)
        mu = self.fc_mu(memory.mean(1))
        logvar = self.fc_logvar(memory.mean(1))
        return mu, logvar

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar) + Config.eps
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, cond, target):
        memory = self.encoder(self.pos_enc(self.embed(cond) * math.sqrt(Config.d_model)))
        z_context = self.latent_proj(z).unsqueeze(1).repeat(1, memory.size(1), 1)
        memory = memory + z_context
        tgt = self.embed(target) * math.sqrt(Config.d_model)
        tgt = self.pos_enc(tgt)
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(target.size(1), device=target.device)
        out = self.decoder(tgt=tgt, memory=memory, tgt_mask=tgt_mask)
        return self.fc_out(out)

    def forward(self, cond, target):
        mu, logvar = self.encode(cond)
        z = self.reparameterize(mu, logvar)
        logits = self.decode(z, cond, target[:, :-1])
        return logits, mu, logvar
In [ ]:
losses = []

def train_model(midi_dir, epochs=50):
    dataset = Music21ConditionalDataset(midi_dir)
    loader = DataLoader(dataset, batch_size=Config.batch_size, shuffle=False)

    model = MusicVAE(len(vocab)).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=Config.lr)

    for epoch in range(epochs):
        model.train()
        total_loss = 0
        kl_weight = min(1.0, epoch / Config.kl_anneal_epochs)
        for melody, harmony in loader:
            melody, harmony = melody.to(device), harmony.to(device)
            logits, mu, logvar = model(melody, harmony)
            loss_rec = F.cross_entropy(logits.view(-1, logits.size(-1)), harmony[:, 1:].contiguous().view(-1), ignore_index=vocab["PAD"])
            kl = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
            loss = loss_rec + kl_weight * kl

            opt.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), Config.grad_clip)
            opt.step()

            total_loss += loss.item()

        print(f"Epoch {epoch+1}, Loss: {total_loss/len(loader):.4f}")
        losses.append(total_loss/len(loader))
        torch.save(model.state_dict(), "musicvae_music21.pth")
In [ ]:
def generate(model_path, melody_path, out_path="output.mid", max_tokens=256):
    model = MusicVAE(len(vocab)).to(device)
    model.load_state_dict(torch.load(model_path))
    model.eval()

    dataset = Music21ConditionalDataset("")
    melody_tokens, _ = dataset.extract_tokens(converter.parse(melody_path))
    melody = [vocab["BOS"]] + melody_tokens[:Config.max_seq_len - 1]
    melody += [vocab["PAD"]] * (Config.max_seq_len - len(melody))
    melody = torch.LongTensor(melody).unsqueeze(0).to(device)

    with torch.no_grad():
        mu, logvar = model.encode(melody)
        z = model.reparameterize(mu, logvar)
        generated = [vocab["BOS"]]
        for _ in range(max_tokens):
            inp = torch.LongTensor([generated[-Config.max_seq_len + 1:]]).to(device)
            out = model.decode(z, melody, inp)[0, -1]
            probs = F.softmax(out / Config.temp, dim=0)
            topk_probs, topk_idx = probs.topk(Config.top_k)
            next_token = topk_idx[torch.multinomial(topk_probs, 1)].item()
            if next_token == vocab["PAD"]:
                break
            generated.append(next_token)

    # Reconstruct score with both melody and chords
    melody_stream = stream.Part()
    for tok_id in melody.squeeze().tolist():
        tok = inv_vocab.get(tok_id, "PAD")
        if tok.startswith("Note_"):
            _, pitch, dur = tok.split("_")
            n = note.Note(int(pitch), quarterLength=float(dur))
            melody_stream.append(n)

    harmony_stream = stream.Part()
    for tok_id in generated:
        tok = inv_vocab.get(tok_id, "PAD")
        if tok.startswith("Roman_"):
            fig = tok.replace("Roman_", "")
            try:
                rn = roman.RomanNumeral(fig, "C")  # placeholder key
                c = rn.pitchedCommonName.split()
                chord_obj = chord.Chord([str(p) for p in rn.pitches])
                chord_obj.quarterLength = 1.0
                harmony_stream.append(chord_obj)
            except:
                pass

    full_score = stream.Score()
    full_score.insert(0, melody_stream)
    full_score.insert(0, harmony_stream)
    full_score.write("midi", fp=out_path)
In [ ]:
train_model("data/midis")
/home/vsinha/.local/lib/python3.11/site-packages/music21/stream/base.py:3675: Music21DeprecationWarning: .flat is deprecated.  Call .flatten() instead
  return self.iter().getElementsByClass(classFilterList)
Epoch 1, Loss: 4.9969
Epoch 2, Loss: 3.6185
Epoch 3, Loss: 3.2209
Epoch 4, Loss: 3.0361
Epoch 5, Loss: 2.8989
Epoch 6, Loss: 2.7716
Epoch 7, Loss: 2.6814
Epoch 8, Loss: 2.6512
Epoch 9, Loss: 2.5464
Epoch 10, Loss: 2.5089
Epoch 11, Loss: 2.4878
Epoch 12, Loss: 2.3944
Epoch 13, Loss: 2.3354
Epoch 14, Loss: 2.3007
Epoch 15, Loss: 2.2703
Epoch 16, Loss: 2.1886
Epoch 17, Loss: 2.1908
Epoch 18, Loss: 2.1190
Epoch 19, Loss: 2.0822
Epoch 20, Loss: 2.0290
Epoch 21, Loss: 2.0261
Epoch 22, Loss: 2.0352
Epoch 23, Loss: 1.9049
Epoch 24, Loss: 1.8353
Epoch 25, Loss: 1.8090
Epoch 26, Loss: 1.7589
Epoch 27, Loss: 1.7259
Epoch 28, Loss: 1.6826
Epoch 29, Loss: 1.6491
Epoch 30, Loss: 1.6276
Epoch 31, Loss: 1.6025
Epoch 32, Loss: 1.5488
Epoch 33, Loss: 1.5583
Epoch 34, Loss: 1.5575
Epoch 35, Loss: 1.4880
Epoch 36, Loss: 1.4543
Epoch 37, Loss: 1.4462
Epoch 38, Loss: 1.4125
Epoch 39, Loss: 1.3815
Epoch 40, Loss: 1.3378
Epoch 41, Loss: 1.3246
Epoch 42, Loss: 1.3538
Epoch 43, Loss: 1.2830
Epoch 44, Loss: 1.3144
Epoch 45, Loss: 1.2397
Epoch 46, Loss: 1.3221
Epoch 47, Loss: 1.2183
Epoch 48, Loss: 1.2184
Epoch 49, Loss: 1.1903
Epoch 50, Loss: 1.1305
In [ ]:
import matplotlib.pyplot as plt

# Loss values you provided
loss_values = losses
# Corresponding epoch numbers
epochs = list(range(1, len(loss_values) + 1))

# Plotting
plt.figure(figsize=(10, 5))
plt.plot(epochs, loss_values, linestyle='-', color='royalblue', label='Training Loss')
plt.title('Training Loss vs Epochs')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.grid(True, linestyle='--', alpha=0.5)
plt.legend()
plt.tight_layout()
plt.show()
In [ ]:
generate("musicvae_music21.pth", "data/midis/Violette, Andrew, 3 Little Pieces, O2KofziUWQU.mid")
In [ ]:
import pretty_midi
import numpy as np
from IPython.display import Audio

def midi_to_audio(midi_path, sample_rate=44100):

    pm = pretty_midi.PrettyMIDI(midi_path)   
    audio = pm.synthesize(fs=sample_rate)
    
    return Audio(audio, rate=sample_rate)

midi_to_audio("output.mid")
Out[ ]:
Your browser does not support the audio element.
In [ ]:
import pretty_midi
import numpy as np
from IPython.display import Audio

def midi_to_audio(midi_path, sample_rate=44100):

    pm = pretty_midi.PrettyMIDI(midi_path)   
    audio = pm.synthesize(fs=sample_rate)
    
    return Audio(audio, rate=sample_rate)

midi_to_audio("data/midis/Violette, Andrew, 3 Little Pieces, O2KofziUWQU.mid")
Out[ ]:
Your browser does not support the audio element.